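### This file contains impls for an MM-DiT variant (MM-DiT is the core model component of SD3) whose joint blocks mix tokens with Mamba selective scans instead of attention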

import math
from typing import Dict, Optional
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat

### This file contains impls for underlying related models (CLIP, T5, etc)

import torch, math
from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast


#################################################################################################
### Core/Utility
#################################################################################################


def attention(q, k, v, heads, mask=None):
    """Multi-head scaled dot-product attention on head-packed tensors.

    Args:
        q, k, v: 3-D tensors whose last dimension packs ``heads`` heads; each is viewed as
            (batch, seq, heads, dim_head) and moved to (batch, heads, seq, dim_head).
        heads: number of attention heads.
        mask: optional ``attn_mask`` forwarded to ``scaled_dot_product_attention``.

    Returns:
        Tensor of shape (batch, seq_q, heads * dim_head).
    """
    batch, _, packed_width = q.shape
    head_width = packed_width // heads

    def _to_heads(tensor):
        # (batch, seq, heads * head_width) -> (batch, heads, seq, head_width)
        return tensor.view(batch, -1, heads, head_width).transpose(1, 2)

    q_heads, k_heads, v_heads = _to_heads(q), _to_heads(k), _to_heads(v)
    attended = torch.nn.functional.scaled_dot_product_attention(
        q_heads,
        k_heads,
        v_heads,
        attn_mask=mask,
        dropout_p=0.0,
        is_causal=False,
    )
    merged = attended.transpose(1, 2)
    return merged.reshape(batch, -1, heads * head_width)


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks: fc1 -> act -> fc2."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, dtype=None, device=None):
        """
        Args:
            in_features: input feature size.
            hidden_features: hidden size; defaults to ``in_features``.
            out_features: output size; defaults to ``in_features``.
            act_layer: activation. Either a module/callable instance (used as-is) or a class
                (instantiated with no arguments), so both ``nn.GELU`` and
                ``nn.GELU(approximate="tanh")`` work.
            bias: whether both linear layers have a bias.
            dtype, device: forwarded to both ``nn.Linear`` layers.
        """
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
        # The default ``nn.GELU`` is a class: calling the class on a tensor would construct a new
        # module instead of applying the activation, so classes are instantiated here.
        self.act = act_layer() if isinstance(act_layer, type) else act_layer
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

def build_mlp(hidden_size, projector_dim, z_dim):
    """Build a 3-layer projector: hidden_size -> projector_dim -> projector_dim -> z_dim.

    Linear layers are separated by SiLU activations (no activation after the last layer).
    """
    widths = [(hidden_size, projector_dim), (projector_dim, projector_dim), (projector_dim, z_dim)]
    layers = []
    for position, (fan_in, fan_out) in enumerate(widths):
        if position > 0:
            layers.append(nn.SiLU())
        layers.append(nn.Linear(fan_in, fan_out))
    return nn.Sequential(*layers)

class PatchEmbed(nn.Module):
    """2D image to patch embedding.

    A ``Conv2d`` with kernel = stride = ``patch_size`` projects each non-overlapping patch to
    ``embed_dim`` channels. With ``flatten=True`` the (N, C, H', W') output is returned as
    (N, H'*W', C) tokens.
    """

    def __init__(
            self,
            img_size: Optional[int] = 224,
            patch_size: int = 16,
            in_chans: int = 3,
            embed_dim: int = 768,
            flatten: bool = True,
            bias: bool = True,
            strict_img_size: bool = True,
            dynamic_img_pad: bool = False,
            dtype=None,
            device=None,
    ):
        super().__init__()
        self.patch_size = (patch_size, patch_size)
        if img_size is None:
            self.img_size = None
            self.grid_size = None
            self.num_patches = None
        else:
            self.img_size = (img_size, img_size)
            self.grid_size = tuple(side // step for side, step in zip(self.img_size, self.patch_size))
            self.num_patches = self.grid_size[0] * self.grid_size[1]

        # Stored configuration; only ``flatten`` is consulted by forward().
        self.flatten = flatten
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
            dtype=dtype,
            device=device,
        )

    def forward(self, x):
        _batch, _chans, _height, _width = x.shape  # requires a 4-D (N, C, H, W) input
        patches = self.proj(x)
        if not self.flatten:
            return patches
        return patches.flatten(2).transpose(1, 2)  # NCHW -> NLC


def modulate(x, shift, scale):
    """Apply adaLN modulation ``x * (1 + scale) + shift``.

    ``scale`` and ``shift`` get a singleton axis inserted at dim 1 before broadcasting
    against ``x``. A ``None`` shift is treated as zeros shaped like ``scale``.
    """
    offset = torch.zeros_like(scale) if shift is None else shift
    gain = 1 + scale.unsqueeze(1)
    return x * gain + offset.unsqueeze(1)


#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scaling_factor=None, offset=None):
    """
    Build a fixed 2D sine/cosine positional embedding for a square grid.

    Args:
        embed_dim: embedding width per position.
        grid_size: int, number of positions along both height and width.
        cls_token: together with ``extra_tokens > 0``, prepend zero rows.
        extra_tokens: number of zero rows to prepend when ``cls_token`` is set.
        scaling_factor: if given, grid coordinates are divided by it.
        offset: if given, subtracted from the (possibly scaled) grid coordinates.

    Returns:
        np.ndarray of shape [grid_size*grid_size, embed_dim], or
        [extra_tokens + grid_size*grid_size, embed_dim] when zero rows are prepended.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # Default "xy" meshgrid indexing: the first stacked channel varies along columns (width).
    grid = np.stack(np.meshgrid(coords, coords), axis=0)
    if scaling_factor is not None:
        grid = grid / scaling_factor
    if offset is not None:
        grid = grid - offset
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid.reshape([2, 1, grid_size, grid_size]))
    if not (cls_token and extra_tokens > 0):
        return pos_embed
    zero_rows = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([zero_rows, pos_embed], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a two-channel coordinate grid with 1D sine/cosine embeddings.

    The first ``embed_dim // 2`` output channels encode ``grid[0]`` and the remaining
    ``embed_dim // 2`` encode ``grid[1]``; each channel's coordinates are flattened to M
    positions, giving an (M, embed_dim) array. ``embed_dim`` must be even.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, axis_coords) for axis_coords in (grid[0], grid[1])]
    return np.concatenate(per_axis, axis=1)  # (M, D)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine/cosine embedding of 1D positions.

    Args:
        embed_dim: output dimension for each position (must be even).
        pos: positions to encode; any array-like (list, tuple or ndarray of any shape),
            flattened to M positions.

    Returns:
        np.ndarray of shape (M, embed_dim): columns [0, D/2) hold sin(pos * omega_i) and
        columns [D/2, D) hold cos(pos * omega_i), with omega_i = 1 / 10000**(2i / D).
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)
    # np.asarray accepts plain lists/tuples as the docstring promises; ``.reshape`` alone needs an ndarray.
    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################


class TimestepEmbedder(nn.Module):
    """Map scalar (possibly fractional) timesteps to ``hidden_size``-dim vectors.

    Timesteps are encoded with fixed sinusoidal features of width ``frequency_embedding_size``
    and then passed through Linear -> SiLU -> Linear.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None):
        super().__init__()
        linear_kwargs = {"bias": True, "dtype": dtype, "device": device}
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, **linear_kwargs),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, **linear_kwargs),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Sinusoidal timestep features.

        :param t: 1-D tensor with one (possibly fractional) timestep per batch element.
        :param dim: output width; odd widths get a trailing zero column.
        :param max_period: sets the lowest frequency, 1 / max_period.
        :return: (N, dim) tensor laid out as [cos(t * f_i) | sin(t * f_i) | optional zero],
                 cast to t's dtype when t is floating point.
        """
        half = dim // 2
        steps = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * steps / half).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        pieces = [torch.cos(angles), torch.sin(angles)]
        if dim % 2:
            pieces.append(torch.zeros_like(angles[:, :1]))
        embedding = torch.cat(pieces, dim=-1)
        if torch.is_floating_point(t):
            embedding = embedding.to(dtype=t.dtype)
        return embedding

    def forward(self, t, dtype, **kwargs):
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features.to(dtype))


class VectorEmbedder(nn.Module):
    """Project a flat ``input_dim`` vector to ``hidden_size`` via Linear -> SiLU -> Linear."""

    def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        first = nn.Linear(input_dim, hidden_size, bias=True, **factory)
        second = nn.Linear(hidden_size, hidden_size, bias=True, **factory)
        self.mlp = nn.Sequential(first, nn.SiLU(), second)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


#################################################################################
#                                 Core DiT Model                                #
#################################################################################


def split_qkv(qkv, head_dim):
    """Split a fused (B, L, 3 * heads * head_dim) projection into q, k and v.

    Returns three tensors of shape (B, L, heads, head_dim).
    """
    batch, seq_len = qkv.shape[0], qkv.shape[1]
    stacked = qkv.reshape(batch, seq_len, 3, -1, head_dim).movedim(2, 0)
    query, key, value = (stacked[idx] for idx in range(3))
    return query, key, value

def optimized_attention(qkv, num_heads):
    """Run ``attention`` on an indexable (q, k, v) triple of head-packed tensors."""
    query, key, value = qkv[0], qkv[1], qkv[2]
    return attention(query, key, value, num_heads)
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

class SelfAttention(nn.Module):
    """Selective-scan token mixer used in place of self-attention.

    ``pre_attention`` projects tokens to ``d_inner`` channels, applies a depthwise convolution
    (Conv1d with kernel 1 for text, Conv2d over a square token grid for images) and SiLU, then
    builds K=4 scan sequences plus their per-sequence ``dt``, ``B`` and ``C`` inputs. The scan
    is not run here; ``post_attention`` maps a ``d_inner``-wide result back to ``dim``.
    """

    def __init__(
        self,
        dim,
        d_state=16,
        d_conv=3,
        expand=2.,
        dt_rank=None,
        conv_bias=True,
        bias=False,
        is_text=False,
        pre_only: bool = False,
        bc_norm: Optional[str] = None,
        dtype=None,
        device=None,
    ):
        """
        Args:
            dim: model width (input of ``in_proj`` and output of ``proj``).
            d_state: width of the B and C scan inputs.
            d_conv: depthwise conv kernel size for image tokens; forced to 1 when ``is_text``.
            expand: ``d_inner = int(expand * dim)``.
            dt_rank: rank of the low-rank dt projection. ``None`` or ``"auto"`` resolve to
                ``ceil(dim / 16)`` (the rule JointBlock uses); an int is used as-is.
            conv_bias: bias for the depthwise conv.
            bias: bias for ``in_proj``.
            is_text: treat the input as a 1D token sequence instead of a square image grid.
            pre_only: if True, no output projection is created and ``post_attention`` is unusable.
            bc_norm: ``None`` (identity), ``"rms"`` or ``"ln"`` normalization of B and C.
            dtype, device: factory kwargs for created parameters.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.pre_only = pre_only

        self.d_model = dim
        self.d_state = d_state
        if is_text:
            # Text uses a kernel-1 Conv1d with zero padding, which keeps the sequence length.
            d_conv = 1
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        if dt_rank is None or dt_rank == "auto":
            # Without this, ``self.dt_rank + self.d_state * 2`` below fails for the default None.
            dt_rank = math.ceil(self.d_model / 16)
        self.dt_rank = dt_rank
        if not pre_only:
            self.proj = nn.Linear(self.d_inner, dim, dtype=dtype, device=device)
        self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)
        self.is_text = is_text
        if is_text:
            self.conv1d = nn.Conv1d(
                in_channels=self.d_inner,
                out_channels=self.d_inner,
                bias=conv_bias,
                kernel_size=d_conv,
                groups=self.d_inner,
                padding=d_conv - 1,
                **factory_kwargs,
            )
        else:
            self.conv2d = nn.Conv2d(
                in_channels=self.d_inner,
                out_channels=self.d_inner,
                groups=self.d_inner,
                bias=conv_bias,
                kernel_size=d_conv,
                padding=(d_conv - 1) // 2,
                **factory_kwargs,
            )

        self.act = nn.SiLU()

        # One x -> (dt, B, C) projection per scan direction, kept only as a stacked weight.
        x_proj = [
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs)
            for _ in range(4)
        ]
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in x_proj], dim=0))  # (K=4, N, inner)

        # NOTE(review): ``self.head_dim`` is never set in this class, so the "rms" and "ln" branches
        # raise AttributeError. The intended normalized size needs confirming: ln_b/ln_c are applied
        # to tensors shaped (B, K, d_state, L), i.e. they normalize over the last (L) axis.
        if bc_norm == "rms":
            self.ln_b = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
            self.ln_c = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
        elif bc_norm == "ln":
            self.ln_b = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
            self.ln_c = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
        elif bc_norm is None:
            self.ln_b = nn.Identity()
            self.ln_c = nn.Identity()
        else:
            raise ValueError(bc_norm)

    def pre_attention(self, x: torch.Tensor, dtw=None):
        """Build the K=4 selective-scan inputs for ``x``.

        Args:
            x: tokens of shape (B, L, dim); for image tokens L must be a perfect square.
            dtw: dt projection weights of shape (K, d_inner, dt_rank), as required by the einsum below.

        Returns:
            ``(xs, Bs, Cs, dts)`` shaped (B, K*d_inner, L), (B, K, d_state, L),
            (B, K, d_state, L) and (B, K*d_inner, L).
        """
        x = self.in_proj(x)
        B, L, C = x.shape
        K = 4
        if self.is_text:
            x = x.permute(0, 2, 1).contiguous()
            x = self.act(self.conv1d(x))
        else:
            # Fold tokens into an H = W = sqrt(L) grid for the depthwise 2D conv.
            x = x.view(B, int(np.sqrt(L)), int(np.sqrt(L)), C).contiguous()  # [B,H,W,C]
            x = x.permute(0, 3, 1, 2).contiguous()
            x = self.act(self.conv2d(x))

        if not self.is_text:
            # Scan orders: row-major, column-major, then both of those reversed.
            x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
            xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (B, 4, d_inner, L)
        else:
            # Text: the same forward-order sequence in all K copies.
            xs = x.view(B, 1, -1, L).repeat(1, K, 1, 1)  # (B, K, d_inner, L)
        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), dtw)
        xs = xs.view(B, -1, L)
        dts = dts.contiguous().view(B, -1, L)  # (B, K * d_inner, L)
        Bs = Bs.view(B, K, -1, L)  # (B, K, d_state, L)
        Cs = Cs.view(B, K, -1, L)  # (B, K, d_state, L)
        Cs = self.ln_c(Cs)
        Bs = self.ln_b(Bs)
        return (xs, Bs, Cs, dts)

    def post_attention(self, x: torch.Tensor) -> torch.Tensor:
        """Project a ``d_inner``-wide mixer output back to ``dim``; not available when ``pre_only``."""
        assert not self.pre_only
        x = self.proj(x)
        return x


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension, with an optional learnable scale."""

    def __init__(
        self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
    ):
        """
        Initialize the RMSNorm normalization layer.
        Args:
            dim (int): Size of the last (normalized) dimension.
            elementwise_affine (bool, optional): If True, learn a per-channel scale ``weight`` of
                shape (dim,), initialized to ones. Default is False (no parameters).
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
            device, dtype: Factory kwargs for ``weight``.
        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            learnable_scale (bool): Whether ``weight`` exists.
            weight (nn.Parameter or None): Learnable scaling parameter.
        """
        super().__init__()
        self.eps = eps
        self.learnable_scale = elementwise_affine
        if self.learnable_scale:
            # Identity scale at init; torch.empty would leave uninitialized memory as the scale.
            self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
        else:
            self.register_parameter("weight", None)

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: x / sqrt(mean(x**2, last dim) + eps).
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The normalized tensor, multiplied by ``weight`` (cast to x's device/dtype) if present.
        """
        x = self._norm(x)
        if self.learnable_scale:
            return x * self.weight.to(device=x.device, dtype=x.dtype)
        else:
            return x


class SwiGLUFeedForward(nn.Module):
    """SwiGLU feed-forward: ``w2(silu(w1(x)) * w3(x))`` with bias-free linear layers."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float] = None,
        dtype=None,
        device=None,
    ):
        """
        Initialize the FeedForward module.

        The effective hidden width is ``int(2 * hidden_dim / 3)``, optionally scaled by
        ``ffn_dim_multiplier``, then rounded up to a multiple of ``multiple_of``.

        Args:
            dim (int): Input and output dimension.
            hidden_dim (int): Nominal hidden dimension before the adjustments above.
            multiple_of (int): Hidden dimension is rounded up to a multiple of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
            dtype, device (optional): Factory kwargs for the linear layers, matching the other
                layers in this file. Default None keeps PyTorch's defaults.

        Attributes:
            w1 (nn.Linear): dim -> hidden gate projection (passed through SiLU).
            w2 (nn.Linear): hidden -> dim output projection.
            w3 (nn.Linear): dim -> hidden value projection.
        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        factory_kwargs = {"dtype": dtype, "device": device}
        self.w1 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)

    def forward(self, x):
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))


class DismantledBlock(nn.Module):
    """A DiT block with gated adaptive layer norm (adaLN) conditioning.

    Token mixing goes through ``self.attn`` (a ``SelfAttention`` module): ``pre_attention``
    returns its mixer inputs, and the mixer result is supplied by the caller to
    ``post_attention``. With ``pre_only=True`` the block has no second norm and no MLP, and
    ``post_attention`` must not be called.
    """

    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")

    def __init__(
        self,
        hidden_size: int=768,
        mlp_ratio: float = 4.0,
        attn_mode: str = "xformers",
        pre_only: bool = False,
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        bc_norm: Optional[str] = None,
        d_state=16,
        expand=2.,
        dt_rank="auto",
        dtype=None,
        device=None,
        is_text: bool = False,
        **block_kwargs,
    ):
        """
        Args:
            hidden_size: model width.
            mlp_ratio: MLP hidden width is ``int(hidden_size * mlp_ratio)``.
            attn_mode: validated against ``ATTENTION_MODES``; not otherwise used here.
            pre_only: build only the pre-mixer path (no norm2 / MLP, fewer adaLN outputs).
            rmsnorm: use parameter-free RMSNorm instead of parameter-free LayerNorm.
            scale_mod_only: adaLN predicts scales and gates only (shifts treated as zero).
            swiglu: use ``SwiGLUFeedForward`` instead of ``Mlp`` with tanh-GELU.
            bc_norm, d_state, expand, is_text: forwarded to ``SelfAttention``.
            dt_rank: forwarded to ``SelfAttention``; ``"auto"`` resolves to ``ceil(hidden_size / 16)``.
            dtype, device: factory kwargs for created layers.
            **block_kwargs: accepted and ignored.
        """
        super().__init__()
        assert attn_mode in self.ATTENTION_MODES
        if dt_rank == "auto":
            # Resolve here (same rule as JointBlock) so SelfAttention always gets an integer rank;
            # passing the string through made this class's own default unusable.
            dt_rank = math.ceil(hidden_size / 16)
        if not rmsnorm:
            self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        else:
            self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = SelfAttention(dim=hidden_size, d_state=d_state, expand=expand, dt_rank=dt_rank, pre_only=pre_only, bc_norm=bc_norm, is_text=is_text, dtype=dtype, device=device)
        if not pre_only:
            if not rmsnorm:
                self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
            else:
                self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        if not pre_only:
            if not swiglu:
                self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=nn.GELU(approximate="tanh"), dtype=dtype, device=device)
            else:
                # NOTE(review): dtype/device are not forwarded to the SwiGLU MLP here.
                self.mlp = SwiGLUFeedForward(dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256)
        self.scale_mod_only = scale_mod_only
        # adaLN outputs: (shift, scale, gate) x (msa, mlp), minus shifts / mlp terms as configured.
        if not scale_mod_only:
            n_mods = 6 if not pre_only else 2
        else:
            n_mods = 4 if not pre_only else 1
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device))
        self.pre_only = pre_only

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor, dtw=None):
        """Normalize ``x``, adaLN-modulate it by ``c``, and build mixer inputs.

        Returns:
            ``(xbcd, intermediates)``: ``xbcd`` is ``self.attn.pre_attention``'s output;
            ``intermediates`` is ``(x, gate_msa, shift_mlp, scale_mlp, gate_mlp)`` for
            ``post_attention``, or ``None`` when ``pre_only``.
        """
        assert x is not None, "pre_attention called with None input"
        if not self.pre_only:
            if not self.scale_mod_only:
                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
            else:
                shift_msa = None
                shift_mlp = None
                scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(4, dim=1)
            xbcd = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa), dtw=dtw)
            return xbcd, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
        else:
            if not self.scale_mod_only:
                shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)
            else:
                shift_msa = None
                scale_msa = self.adaLN_modulation(c)
            xbcd = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa), dtw=dtw)
            return xbcd, None

    def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
        """Gated residual update with the mixer output ``attn``, then a gated residual MLP."""
        assert not self.pre_only
        x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


def block_mixing(context, x, context_block, x_block, c,a=None, d=None, dtw=None, dtb=None,d_state=None):
    """Run one shared selective scan over concatenated context and x tokens.

    Both blocks' ``pre_attention`` return scan inputs ``(xs, Bs, Cs, dts)``. These are
    concatenated along the last (sequence) axis with the context tokens first, scanned once
    with ``selective_scan_fn``, split back apart, and fed to each block's ``post_attention``.

    Args:
        context: context tokens, passed to ``context_block.pre_attention``.
        x: image tokens, passed to ``x_block.pre_attention``. The un-permuting below assumes
            the image token count is a perfect square (H == W).
        context_block, x_block: the two blocks (DismantledBlock instances in JointBlock).
        c: conditioning forwarded to both ``pre_attention`` calls.
        a: A_log parameters; the scan uses ``A = -exp(a)`` reshaped to (-1, d_state).
        d: D skip parameters, flattened to 1-D.
        dtw: dt projection weights forwarded to both ``pre_attention`` calls.
        dtb: dt bias, flattened and passed as ``delta_bias``.
        d_state: scan state size, used to reshape ``a``.
        a, d, dtb and d_state default to None but are required; JointBlock always supplies them.

    Returns:
        ``(context, x)``: updated context (``None`` if ``context_block.pre_only``) and updated x.
    """
    assert context is not None, "block_mixing called with None context"
    # "*_qkv" is a legacy name from the attention version; each holds (xs, Bs, Cs, dts).
    context_qkv, context_intermediates = context_block.pre_attention(context, c, dtw)

    x_qkv, x_intermediates = x_block.pre_attention(x, c,dtw)
    # Image tokens are assumed to form a square grid: H = W = sqrt(LX).
    H = W = int(np.sqrt(x_qkv[0].shape[-1]))
    LX = x_qkv[0].shape[-1]
    K = 4  # number of scan directions
    o = []
    # Concatenate context and image scan inputs along the sequence axis, context first.
    # NOTE: for the reversed image directions (2, 3) a text context block's copies are not
    # reversed, so the scan sees context tokens followed by reversed image tokens.
    for t in range(4):
        o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=-1))
    xs, bs, cs, dts = tuple(o)
    B ,_ , L= xs.shape
    # The scan parameters are cast to fp32.
    d = d.float().view(-1)
    a = -torch.exp(a.float()).view(-1, d_state)
    dtb = dtb.float().view(-1) # (k * d)
    out_y = selective_scan_fn(
            xs, dts,
            a, bs, cs, d, z=None,
            delta_bias=dtb,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)

    # Split the scan output back into context and image parts along the sequence axis.
    context_attn, x_attn = (out_y[:, : , :, : context_qkv[0].shape[-1]], out_y[:, :,:,context_qkv[0].shape[-1] :])
    
    # Undo the image scan orders built in SelfAttention.pre_attention before summing them:
    # directions 2/3 are un-reversed, and directions 1/3 are mapped from column-major back to row-major.
    inv_y = torch.flip(x_attn[:, 2:4], dims=[-1]).view(B, 2, -1, LX)
    wh_y = torch.transpose(x_attn[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, LX)
    invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, LX)
    x_attn = x_attn[:, 0]+ inv_y[:, 0]+ wh_y+ invwh_y
    # (B, d_inner, LX) -> (B, LX, d_inner)
    x_attn = torch.transpose(x_attn, dim0=1, dim1=2).contiguous().view(B, LX, -1)
    # Context outputs from all K directions are summed: (B, K, d_inner, Lc) -> (B, Lc, d_inner).
    summed = context_attn.sum(dim=1)
    context_attn = summed.transpose(1, 2)
    
    if not context_block.pre_only:
        context = context_block.post_attention(context_attn, *context_intermediates)
    else:
        context = None
    x = x_block.post_attention(x_attn, *x_intermediates)
    return context, x


class JointBlock(nn.Module):
    """Context (text) block + x (image) block that share one selective scan.

    A thin wrapper (kept as a convenient FSDP unit) that owns the scan parameters shared by both
    streams (dt projections, A_log and D) and delegates the forward pass to ``block_mixing``.
    """

    def __init__(self, *args, **kwargs):
        """
        Keyword args consumed here; every other kwarg is forwarded to both DismantledBlocks:
            pre_only (required): applied to the context block only; the x block is never pre-only.
            bc_norm: forwarded to both blocks.
            dtype, device: forwarded to both blocks and used for the scan parameters, so every
                parameter of the joint block is created with the same dtype/device.
            hidden_size (768): model width.
            d_state (16), expand (2.0), dt_rank ("auto" -> ceil(hidden_size / 16)), dt_min (0.001),
            dt_max (0.1), dt_init ("random"), dt_scale (1.0), dt_init_floor (1e-4):
                selective-scan hyperparameters, overridable via kwargs.
        """
        super().__init__()

        pre_only = kwargs.pop("pre_only")
        bc_norm = kwargs.pop("bc_norm", None)
        dtype = kwargs.pop("dtype", None)
        device = kwargs.pop("device", None)
        self.d_model = kwargs.pop("hidden_size", 768)
        factory_kwargs = {"device": device, "dtype": dtype}
        # Popped (not just defaulted) so they never reach DismantledBlock via **kwargs, where
        # d_state/expand/dt_rank would otherwise be passed twice.
        d_state = kwargs.pop("d_state", 16)
        expand = kwargs.pop("expand", 2.)
        dt_rank = kwargs.pop("dt_rank", "auto")
        dt_min = kwargs.pop("dt_min", 0.001)
        dt_max = kwargs.pop("dt_max", 0.1)
        dt_init_mode = kwargs.pop("dt_init", "random")
        dt_scale = kwargs.pop("dt_scale", 1.0)
        dt_init_floor = kwargs.pop("dt_init_floor", 1e-4)

        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.d_state = d_state
        self.expand = expand
        self.dt_min = dt_min
        self.dt_max = dt_max
        self.dt_scale = dt_scale
        self.dt_init_floor = dt_init_floor
        self.d_inner = int(self.expand * self.d_model)

        # One dt projection per scan direction (K=4), kept only as stacked weight/bias parameters.
        dt_projs = [
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init_mode, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs)
            for _ in range(4)
        ]
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0))  # (K=4, inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0))  # (K=4, inner)

        self.context_block = DismantledBlock(*args, hidden_size=self.d_model, d_state=d_state, expand=expand, dt_rank=self.dt_rank, pre_only=pre_only, bc_norm=bc_norm, is_text=True, dtype=dtype, device=device, **kwargs)
        self.x_block = DismantledBlock(*args, hidden_size=self.d_model, d_state=d_state, expand=expand, dt_rank=self.dt_rank, pre_only=False, bc_norm=bc_norm, dtype=dtype, device=device, **kwargs)

        self.A_logs = self.A_log_init(d_state, self.d_inner, copies=4, device=device, merge=True)  # (K=4*D, N)
        self.Ds = self.D_init(self.d_inner, copies=4, device=device, merge=True)  # (K=4*D)

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
                **factory_kwargs):
        """Create the ``nn.Linear(dt_rank, d_inner)`` dt projection.

        Weights are uniform in +/- ``dt_rank**-0.5 * dt_scale`` ("random") or constant at that
        value ("constant"). The bias is set to the inverse softplus of dt values sampled
        log-uniformly in [dt_min, dt_max] and clamped below at ``dt_init_floor``.

        Raises:
            NotImplementedError: for any other ``dt_init``.
        """
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = dt_rank ** -0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        dt_proj.bias._no_reinit = True

        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        """S4D-real init: ``A_log[d, n] = log(n + 1)``, shape (d_inner, d_state), kept in fp32.

        With ``copies > 1`` it is repeated along a new leading axis, which ``merge`` flattens
        into (copies * d_inner, d_state). Marked ``_no_weight_decay``.
        """
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        if copies > 1:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=1, device=None, merge=True):
        """D "skip" parameter: ones of shape (d_inner,), repeated/merged like ``A_log_init``; fp32."""
        D = torch.ones(d_inner, device=device)
        if copies > 1:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)  # Keep in fp32
        D._no_weight_decay = True
        return D

    def forward(self, *args, **kwargs):
        """Run ``block_mixing`` with this block's sub-blocks and shared scan parameters."""
        return block_mixing(*args, context_block=self.context_block, x_block=self.x_block, a=self.A_logs, d=self.Ds, dtw=self.dt_projs_weight, dtb=self.dt_projs_bias, d_state=self.d_state, **kwargs)


class FinalLayer(nn.Module):
    """Output head of the DiT: adaLN-modulated LayerNorm followed by a linear projection.

    The projection emits ``patch_size * patch_size * out_channels`` features per token, or
    ``total_out_channels`` when that is given.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, total_out_channels: Optional[int] = None, dtype=None, device=None):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        if total_out_channels is None:
            projected_width = patch_size * patch_size * out_channels
        else:
            projected_width = total_out_channels
        self.linear = nn.Linear(hidden_size, projected_width, bias=True, dtype=dtype, device=device)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Modulate the normalized ``x`` with shift/scale predicted from ``c``, then project."""
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=1)
        normalized = self.norm_final(x)
        return self.linear(modulate(normalized, shift, scale))


class mmdimam(nn.Module):
    """MM-DiT denoiser with optional Mamba ("wave") blocks and REPA projectors.

    Image latents are patch-embedded and processed jointly with projected
    context tokens by a stack of ``depth`` ``JointBlock``s, then decoded by
    ``FinalLayer`` and un-patchified back to an image-shaped tensor.

    Optional extras:
      * ``use_wave``: each of the last ``wave_depth`` (6) joint blocks is
        followed by a ``DiMBlockCombined`` Mamba block that acts on the image
        tokens only. It is conditioned on ``timestep embedding +
        mean-pooled raw context``.
      * ``use_projector``: after joint block ``encoder_depth`` the image tokens
        are passed through one MLP per entry of ``z_dims``. The outputs
        (``zs``) are returned alongside the prediction (REPA).
      * ``is_rag``: the supplied ``mip`` module's ``attention_process`` layers
        are applied to the image tokens after joint blocks
        ``mip.start_rag``..``mip.end_rag`` (inclusive).
    """

    def __init__(
        self,
        input_size: int = 32,
        patch_size: int = 2,
        in_channels: int = 4,
        depth: int = 24,
        hidden_size=768,
        mlp_ratio: float = 4.0,
        learn_sigma: bool = False,
        adm_in_channels: Optional[int] = None,
        context_embedder_config: Optional[Dict] = None,
        register_length: int = 0,
        attn_mode: str = "torch",
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        out_channels: Optional[int] = None,
        pos_embed_scaling_factor: Optional[float] = None,
        pos_embed_offset: Optional[float] = None,
        pos_embed_max_size: Optional[int] = None,
        num_patches = None,
        bc_norm: Optional[str] = None,
        qkv_bias: bool = True,
        dtype = None,
        device = None,
        encoder_depth = 8,
        z_dims=(768,),
        projector_dim=2048,
        use_wave=True,
        is_rag=False,
        use_projector=True,
        mip=None
    ):
        super().__init__()
        print(f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {bc_norm=}, {qkv_bias=}, {dtype=}, {device=}")
        self.dtype = dtype
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        default_out_channels = in_channels * 2 if learn_sigma else in_channels
        self.out_channels = out_channels if out_channels is not None else default_out_channels
        self.patch_size = patch_size
        self.pos_embed_scaling_factor = pos_embed_scaling_factor
        self.pos_embed_offset = pos_embed_offset
        # NOTE: hard-coded; the `pos_embed_max_size` argument is ignored.
        # `cropped_pos_embed` views the pos_embed buffer as an 8x8 grid, so
        # x_embedder must produce exactly 64 patches.
        self.pos_embed_max_size = pos_embed_max_size = 8
        self.is_rag = is_rag
        # apply magic --> this defines a head_size of 64
        # hidden_size = 64 * depth
        self.hidden_size = hidden_size
        num_heads = depth
        self.depth = depth
        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=self.pos_embed_max_size is None, dtype=dtype, device=device)
        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device)

        if adm_in_channels is not None:
            assert isinstance(adm_in_channels, int)
            self.y_embedder = VectorEmbedder(adm_in_channels, hidden_size, dtype=dtype, device=device)

        self.context_embedder = nn.Identity()
        # TODO: hand coded -- this overrides the `context_embedder_config`
        # argument: context tokens are always projected 768 -> hidden_size.
        context_embedder_config = {"params": {"in_features": 768, "out_features": hidden_size}, "target": "torch.nn.Linear"}
        if context_embedder_config is not None:
            if context_embedder_config["target"] == "torch.nn.Linear":
                self.context_embedder = nn.Linear(**context_embedder_config["params"], dtype=dtype, device=device)

        self.register_length = register_length
        if self.register_length > 0:
            self.register = nn.Parameter(torch.randn(1, register_length, hidden_size, dtype=dtype, device=device))

        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        # just use a buffer already
        if num_patches is not None:
            self.register_buffer(
                "pos_embed",
                torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device),
            )
        else:
            self.pos_embed = None

        self.use_wave = use_wave
        if self.use_wave:
            # Mamba-specific dependencies are imported lazily so that a model
            # built with use_wave=False does not require mamba_ssm.
            from functools import partial
            from src.models.denoiser.mmdim import DiMBlockCombined, gen_paths
            from mamba_ssm.modules.mamba_simple import CondMamba
            try:
                from mamba_ssm.ops.triton.layernorm import RMSNorm as mamba_RMSNorm
            except ImportError:
                # Bind the name so a missing dependency is reported by the
                # explicit check below instead of surfacing as a NameError.
                mamba_RMSNorm = None

            wave_depth = 6
            drop_path = 0.1

            self.wave_depth = wave_depth
            grid_size = int(math.sqrt(num_patches))
            ssm_cfg = None
            rms_norm = True
            norm_epsilon = 1e-5
            scan_type = 'sweep_4'
            transpose = (scan_type == "none")
            fused_add_norm = True
            residual_in_fp32 = True
            scanning_continuity = True
            use_gated_mlp = False
            if scan_type.startswith("zigma") or scan_type.startswith("sweep") or scan_type.startswith("jpeg"):
                block_kwargs = gen_paths(grid_size, scan_type, wave_depth)
            else:
                block_kwargs = {}
            if ssm_cfg is None:
                ssm_cfg = {}
            factory_kwargs = {"device": device, "dtype": dtype}
            if rms_norm and mamba_RMSNorm is None:
                raise ImportError(
                    "use_wave=True with rms_norm requires RMSNorm from "
                    "mamba_ssm.ops.triton.layernorm, which failed to import"
                )
            norm_cls = partial(nn.LayerNorm if not rms_norm else mamba_RMSNorm, eps=norm_epsilon, **factory_kwargs)
            # NOTE: the schedule spans self.depth entries, but only the first
            # wave_depth entries of inter_dpr are consumed by the wave blocks.
            dpr = [x.item() for x in torch.linspace(0, drop_path, self.depth)]  # stochastic depth decay rule
            inter_dpr = [0.0] + dpr

            def creat_wave(layer_idx):
                """Build the wave (Mamba) block with index `layer_idx`."""
                reverse = (scan_type == "none") and (layer_idx % 2 > 0)
                mixer_cls = partial(
                    CondMamba,
                    layer_idx=layer_idx,
                    d_cond=hidden_size,
                    **ssm_cfg,
                    **block_kwargs,
                    **factory_kwargs,
                )
                block = DiMBlockCombined(
                    hidden_size,
                    mixer_cls,
                    norm_cls=norm_cls,
                    drop_path=inter_dpr[layer_idx],
                    fused_add_norm=fused_add_norm,
                    residual_in_fp32=residual_in_fp32,
                    reverse=reverse,
                    transpose=transpose,
                    scanning_continuity=scanning_continuity,
                    use_gated_mlp=use_gated_mlp,
                )
                return block

            self.wave_blocks = nn.ModuleList(
                [
                    creat_wave(i)
                    for i in range(wave_depth)
                ]
            )

        self.joint_blocks = nn.ModuleList(
            [
                JointBlock(hidden_size=hidden_size, mlp_ratio=mlp_ratio, pre_only=i == depth - 1, rmsnorm=rmsnorm, scale_mod_only=scale_mod_only, swiglu=swiglu, bc_norm=bc_norm, dtype=dtype, device=device)
                for i in range(depth)
            ]
        )

        self.mip = None
        if self.is_rag:
            self.mip = mip

        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, dtype=dtype, device=device)

        # REPA: projection heads applied to the tokens after `encoder_depth` blocks.
        self.encoder_depth = encoder_depth
        self.use_projectors = use_projector
        if self.use_projectors:
            self.projectors = nn.ModuleList([
                build_mlp(hidden_size, projector_dim, z_dim) for z_dim in z_dims
                ])

        # Initialize pos_embed (a non-trainable buffer) with a 2D sin-cos embedding.
        if self.pos_embed is not None:
            pos_embed = get_2d_sincos_pos_embed(
                self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)
                )
            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        self.initialize_weights()

    def initialize_weights(self):
        """Zero the last adaLN modulation layer of every wave block (if any)."""
        if self.use_wave:
            for block in self.wave_blocks:
                nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def cropped_pos_embed(self, hw):
        """Return the centre crop of pos_embed matching an input of spatial size `hw`.

        pos_embed is viewed as a (pos_embed_max_size x pos_embed_max_size) grid
        of tokens; the central (h/p x w/p) window is returned as (1, h*w/p^2, C).
        """
        assert self.pos_embed_max_size is not None
        p = self.x_embedder.patch_size[0]
        h, w = hw
        # patched size
        h = h // p
        w = w // p
        assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
        assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
        top = (self.pos_embed_max_size - h) // 2
        left = (self.pos_embed_max_size - w) // 2
        spatial_pos_embed = rearrange(
            self.pos_embed,
            "1 (h w) c -> 1 h w c",
            h=self.pos_embed_max_size,
            w=self.pos_embed_max_size,
        )
        spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
        spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
        return spatial_pos_embed

    def unpatchify(self, x, hw=None):
        """
        Fold patch tokens back into an image-shaped tensor.

        x: (N, T, patch_size**2 * C)
        hw: optional (H, W) of the original input; if omitted, a square
            token grid is assumed.
        returns imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        if hw is None:
            h = w = int(x.shape[1] ** 0.5)
        else:
            h, w = hw
            h = h // p
            w = w // p
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def forward_core_with_concat(
            self, x: torch.Tensor, c_mod: torch.Tensor, context: Optional[torch.Tensor] = None,
            detach: Optional[bool] = False,
            inference_params=None, c_wave=None,
            x_clip=None
            ) -> torch.Tensor:
        """Run the transformer stack on embedded tokens.

        Returns a tuple ``(x, zs)``: the FinalLayer output
        (N, T, patch_size**2 * out_channels) and the list of REPA projector
        outputs (or None when no projector ran).
        """
        if self.register_length > 0:
            context = torch.cat((repeat(self.register, "1 ... -> b ...", b=x.shape[0]), context if context is not None else torch.Tensor([]).type_as(x)), 1)

        # context is B, L', D
        # x is B, L, D
        B, L, D = x.shape
        wi = 0  # index of the next wave block
        ri = 0  # index of the next mip attention layer
        zs = None
        # The mip layers attend to the context as it was before any joint block.
        context_rag = context
        ref_img = None
        if self.mip is not None:
            ref_img = self.mip.img_linear(x_clip)

        residual = None
        for i, block in enumerate(self.joint_blocks):
            context, x = block(context, x, c=c_mod)
            if self.use_projectors and (i + 1) == self.encoder_depth:
                if detach:
                    x_ = x.clone().detach()
                    zs = [projector(x_.reshape(-1, D)).reshape(B, L, -1) for projector in self.projectors]
                else:
                    zs = [projector(x.reshape(-1, D)).reshape(B, L, -1) for projector in self.projectors]
            if self.mip is not None and i >= self.mip.start_rag and i <= self.mip.end_rag:
                x = self.mip.attention_process[ri](x=x, ref_img=ref_img, context=context_rag)
                ri += 1
            # Wave blocks follow each of the last `wave_depth` joint blocks.
            if self.use_wave and (self.depth - i) <= self.wave_depth:
                if residual is None:
                    residual = x
                x, residual = self.wave_blocks[wi](x, residual, c=c_wave, inference_params=inference_params)
                wi += 1

        x = self.final_layer(x, c_mod)  # (N, T, patch_size ** 2 * out_channels)
        return x, zs

    def forward(
            self, x: torch.Tensor, t: torch.Tensor,  context: Optional[torch.Tensor] = None,
            detach: Optional[bool] = False,
            x_clip=None
            ) -> torch.Tensor:
        """
        Forward pass of the denoiser.

        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        context: context tokens; required despite the None default. Presumably
            (N, L', 768) to match the hand-coded context_embedder -- TODO confirm.
        detach: if True, the REPA projectors receive a detached copy of the tokens.
        x_clip: features fed to ``mip.img_linear`` when RAG is enabled.

        Returns ``(prediction, zs)``: prediction is (N, out_channels, H, W);
        zs is the list of projector outputs, or None.
        """
        if context is None:
            raise ValueError("mmdimam.forward requires `context`; it is pooled into the wave-block conditioning")
        hw = x.shape[-2:]
        x = self.x_embedder(x) + self.cropped_pos_embed(hw)
        c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        # Wave blocks are conditioned on the timestep embedding plus the
        # mean-pooled raw (pre-embedder) context.
        ys = context.mean(1)
        c_wave = c + ys
        context = self.context_embedder(context)

        x, zs = self.forward_core_with_concat(x, c, context, detach, c_wave=c_wave, x_clip=x_clip)

        x = self.unpatchify(x, hw=hw)  # (N, out_channels, H, W)
        return x, zs